"""
compute_Amu.py
~~~~~~~~~~~~~~~

Compute the gauge potential ``A_μ(i)`` for each link ``i`` in a discrete
lattice gauge theory.  The potential is constructed from three pieces:

1.  A per‑link flip count ``n_i`` produced by the companion flip count
    simulator (Volume 1).  Flip counts quantify the number of times a
    context‑dependent operator changes at each link.  If no flip count
    file is configured, every ``n_i`` defaults to 1.
2.  A per‑link kernel value ``ρ_i`` provided by the Volume 3 diagnostics
    package.  A missing kernel file is an error, so only real kernels are
    ever used.
3.  A logistic fractal dimension ``D(n_i)`` and linear pivot weight ``g(D)``.

The logistic dimension is defined by

    D(n) = 1 + 2 / (1 + exp(k * (n - n0)))

where ``k`` and ``n0`` are parameters.  The pivot weight is a linear
function of ``D``: ``g(D) = a * D + b``.  The gauge potential for the
U(1) group is then

    A_i = g * g(D(n_i)) * ρ_i

where ``g`` (written without an argument) is the gauge coupling constant,
not to be confused with the pivot weight ``g(D)``.  For SU(2) and SU(3) the
matrix‑valued kernels are multiplied by the same scalar ``g * g(D(n_i))``.

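The results are saved into ``data_dir`` as ``A_U1.npy``, ``A_SU2.npy`` and
``A_SU3.npy``.  All three kernels (``kernel_U1.npy``, ``kernel_SU2.npy`` and
``kernel_SU3.npy`` in ``data_dir``) are required: if any is missing, a
``FileNotFoundError`` is raised before any ``A_*.npy`` file is written.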
"""

from __future__ import annotations

import os
import yaml
import numpy as np
from typing import Optional


def logistic_dimension(n: np.ndarray, k: float, n0: float) -> np.ndarray:
    """Compute the logistic fractal dimension for an array of flip counts.

    Evaluates ``D(n) = 1 + 2 / (1 + exp(k * (n - n0)))`` element-wise.  The
    result lies in ``[1, 3]`` and equals 2 at ``n == n0``.

    The formula is evaluated through the exact identity
    ``2 / (1 + exp(x)) == 1 - tanh(x / 2)``, i.e. ``D = 2 - tanh(x / 2)``
    with ``x = k * (n - n0)``.  The direct form calls ``np.exp(x)``, which
    overflows (and emits a ``RuntimeWarning``) once ``x`` exceeds ~709 —
    easily reached with large flip counts or a steep ``k``.  ``tanh``
    saturates to ±1 instead and never overflows.

    Parameters
    ----------
    n : array_like
        Flip counts; any shape, and plain sequences or scalars are accepted.
    k : float
        Steepness of the logistic transition.
    n0 : float
        Flip count at the midpoint of the transition.

    Returns
    -------
    numpy.ndarray
        Dimensions with the same shape as ``n`` (a NumPy scalar for scalar
        input).
    """
    x = k * (np.asarray(n) - n0)
    return 2.0 - np.tanh(0.5 * x)


def linear_gD(D: np.ndarray, a: float, b: float) -> np.ndarray:
    """Return the pivot weight ``g(D) = a * D + b``, evaluated element-wise.

    ``a`` is the slope and ``b`` the offset of the linear map applied to the
    dimension array ``D``.
    """
    scaled = a * D
    return scaled + b


def _load_kernel(data_dir: str, group: str) -> np.ndarray:
    """Load ``kernel_<group>.npy`` from *data_dir*, raising if it is absent.

    Raises
    ------
    FileNotFoundError
        If the kernel file does not exist, so a missing kernel is never
        silently replaced by a placeholder.
    """
    kernel_path = os.path.join(data_dir, f'kernel_{group}.npy')
    if not os.path.exists(kernel_path):
        raise FileNotFoundError(f"Real kernel for {group} missing at {kernel_path}")
    # NOTE(review): allow_pickle=True lets a crafted .npy file execute
    # arbitrary code on load; only point data_dir at trusted kernel files.
    return np.load(kernel_path, allow_pickle=True)


def _save_potential(data_dir: str, group: str, potential: np.ndarray) -> None:
    """Write *potential* to ``A_<group>.npy`` in *data_dir* and report the path."""
    out_path = os.path.join(data_dir, f'A_{group}.npy')
    np.save(out_path, potential)
    print(f'Computed A_mu for {group}, saved to {out_path}')


def main(cfg_path: str) -> None:
    """Compute A_mu for the U1, SU2 and SU3 gauge groups.

    Reads the YAML configuration at *cfg_path*.  ``data_dir`` and a relative
    ``flip_counts_path`` are resolved against the config file's directory.
    Loads ``lattice.npy`` (only its length is used), the flip counts (or
    ones when ``flip_counts_path`` is unset, with a warning) and the three
    kernels, then writes ``A_U1.npy``, ``A_SU2.npy`` and ``A_SU3.npy`` into
    ``data_dir``.

    Parameters
    ----------
    cfg_path : str
        Path to the YAML configuration file.

    Raises
    ------
    FileNotFoundError
        If ``lattice.npy``, the configured flip counts file, or any of the
        three kernels is missing.  Kernels are checked before any output is
        written.
    ValueError
        If the flip counts are not a 1-D array with one entry per link.
    """
    import warnings

    with open(cfg_path, encoding='utf-8') as f:
        # An empty YAML document loads as None; treat it as an empty mapping
        # so every setting falls back to its default.
        config = yaml.safe_load(f) or {}

    # Resolve data_dir relative to the directory of the config file, so the
    # result does not depend on the current working directory.
    base_dir = os.path.dirname(os.path.abspath(cfg_path))
    data_dir_cfg = config.get('data_dir', 'data')
    if os.path.isabs(data_dir_cfg):
        data_dir = data_dir_cfg
    else:
        data_dir = os.path.join(base_dir, data_dir_cfg)
    os.makedirs(data_dir, exist_ok=True)

    # The lattice is only used to determine the number of links.
    # NOTE(review): allow_pickle=True executes code from a crafted file;
    # presumably required because the lattice is an object array.
    lattice = np.load(os.path.join(data_dir, 'lattice.npy'), allow_pickle=True)
    num_links = len(lattice)

    flip_counts_path = config.get('flip_counts_path', None)
    if flip_counts_path:
        if not os.path.isabs(flip_counts_path):
            flip_counts_path = os.path.normpath(os.path.join(base_dir, flip_counts_path))
        n = np.load(flip_counts_path)
        # One count per link is required: a mismatched or (N, 1) array would
        # otherwise broadcast silently into a wrongly-shaped potential.
        if n.shape != (num_links,):
            raise ValueError(
                f"Flip counts at {flip_counts_path} have shape {n.shape}; "
                f"expected ({num_links},) to match lattice.npy"
            )
    else:
        warnings.warn(
            'flip_counts_path not set; using n_i = 1 for every link '
            '(real flip counts are recommended)',
            stacklevel=2,
        )
        n = np.ones(num_links)

    # PyYAML only resolves floats containing a '.' (and a signed exponent),
    # so e.g. `1e-3` loads as the string "1e-3"; coerce numeric settings.
    a = float(config.get('pivot_a', 1.0))
    b = float(config.get('pivot_b', 0.0))
    k = float(config.get('logistic_k', 0.1))
    n0 = float(config.get('logistic_n0', 5.0))
    g_coupling = float(config.get('g', 1.0))

    D = logistic_dimension(n, k, n0)
    gD = linear_gD(D, a, b)

    # Load every kernel before writing anything, so a missing kernel aborts
    # the run without leaving a partial set of A_*.npy files behind.
    K_u1 = _load_kernel(data_dir, 'U1')
    K_su2 = _load_kernel(data_dir, 'SU2')
    K_su3 = _load_kernel(data_dir, 'SU3')

    # U(1): kernel scaled element-wise (presumably one value per link).
    _save_potential(data_dir, 'U1', g_coupling * gD * K_u1)

    # SU(2)/SU(3): the per-link weight is broadcast over the two trailing
    # (matrix) axes of the kernel.
    matrix_weight = g_coupling * gD[:, np.newaxis, np.newaxis]
    _save_potential(data_dir, 'SU2', matrix_weight * K_su2)
    _save_potential(data_dir, 'SU3', matrix_weight * K_su3)


if __name__ == '__main__':
    import sys

    # Optional first CLI argument: path to the YAML config file.
    cli_args = sys.argv[1:]
    main(cli_args[0] if cli_args else 'config.yaml')